import colorsys
import os
import time
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.applications.imagenet_utils import preprocess_input
from PIL import ImageDraw, ImageFont, Image

from nets.ssd import SSD300
from utils.utils_bbox import BBoxUtility
from utils.utils import get_classes, resize_image, cvtColor
from utils.anchors import get_anchors

'''
训练自己的数据集必看！
'''


class SSD(object):
    _defaults = {
        # --------------------------------------------------------------------------#
        #   使用自己训练好的模型进行预测一定要修改model_path和classes_path！
        #   model_path指向logs文件夹下的权值文件，classes_path指向model_data下的txt
        #
        #   训练好后logs文件夹下存在多个权值文件，选择验证集损失较低的即可。
        #   验证集损失较低不代表mAP较高，仅代表该权值在验证集上泛化性能较好。
        #   如果出现shape不匹配，同时要注意训练时的model_path和classes_path参数的修改
        # --------------------------------------------------------------------------#
        "model_path": 'logs/ep092-loss2.261-val_loss2.360.h5',
        "classes_path": 'model_data/face_classs.txt',
        # ---------------------------------------------------------------------#
        #   用于预测的图像大小，和train时使用同一个即可
        # ---------------------------------------------------------------------#
        "input_shape": [300, 300],
        # ---------------------------------------------------------------------#
        #   只有得分大于置信度的预测框会被保留下来
        # ---------------------------------------------------------------------#
        "confidence": 0.5,
        # ---------------------------------------------------------------------#
        #   非极大抑制所用到的nms_iou大小
        # ---------------------------------------------------------------------#
        "nms_iou": 0.45,
        # ---------------------------------------------------------------------#
        #   用于指定先验框的大小
        # ---------------------------------------------------------------------#
        'anchors_size': [30, 60, 111, 162, 213, 264, 315],
        # ---------------------------------------------------------------------#
        #   该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize，
        #   在多次测试后，发现关闭letterbox_image直接resize的效果更好
        # ---------------------------------------------------------------------#
        "letterbox_image": False,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    # ---------------------------------------------------#
    #   初始化ssd
    # ---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        # ---------------------------------------------------#
        #   计算总的类的数量
        # ---------------------------------------------------#
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors = get_anchors(self.input_shape, self.anchors_size)
        self.num_classes = self.num_classes + 1

        # ---------------------------------------------------#
        #   画框设置不同的颜色
        # ---------------------------------------------------#
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        self.bbox_util = BBoxUtility(self.num_classes, nms_thresh=self.nms_iou)
        self.generate()

    # ---------------------------------------------------#
    #   载入模型
    # ---------------------------------------------------#
    def generate(self):
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # -------------------------------#
        #   载入模型与权值
        # -------------------------------#
        self.ssd = SSD300([self.input_shape[0], self.input_shape[1], 3], self.num_classes)
        self.ssd.load_weights(self.model_path, by_name=True)
        print('{} model, anchors, and classes loaded.'.format(model_path))

    @tf.function
    def get_pred(self, photo):
        preds = self.ssd(photo, training=False)
        return preds

    # ---------------------------------------------------#
    #   检测图片
    # ---------------------------------------------------#
    def detect_image(self, image):
        image_shape = np.array(np.shape(image)[0:2])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(np.uint8(image))
        # ---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        # ---------------------------------------------------------#
        image = cvtColor(image)
        # ---------------------------------------------------------#
        #   给图像增加灰条，实现不失真的resize
        #   也可以直接resize进行识别
        # ---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # ---------------------------------------------------------#
        #   添加上batch_size维度，图片预处理，归一化。
        # ---------------------------------------------------------#
        image_data = preprocess_input(np.expand_dims(np.array(image_data, dtype='float32'), 0))

        preds = self.get_pred(image_data).numpy()
        # -----------------------------------------------------------#
        #   将预测结果进行解码
        # -----------------------------------------------------------#
        results = self.bbox_util.decode_box(preds, self.anchors, image_shape,
                                            self.input_shape, self.letterbox_image, confidence=self.confidence)
        # --------------------------------------#
        #   如果没有检测到物体，则返回原图
        # --------------------------------------#
        output = []
        if len(results[0]) <= 0:
            output = []
            return output

        top_label = np.array(results[0][:, 4], dtype='int32')
        top_conf = results[0][:, 5]
        top_boxes = results[0][:, :4]
        # ---------------------------------------------------------#
        #   图像绘制
        # ---------------------------------------------------------#

        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = top_conf[i]

            top, left, bottom, right = box

            top = max(0, np.floor(top).astype('int32'))
            left = max(0, np.floor(left).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom).astype('int32'))
            right = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            # print(label, top, left, bottom, right)
            output.append([label, top, left, bottom, right])


        return output

    def get_FPS(self, image, test_interval):
        image_shape = np.array(np.shape(image)[0:2])
        # ---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        # ---------------------------------------------------------#
        image = cvtColor(image)
        # ---------------------------------------------------------#
        #   给图像增加灰条，实现不失真的resize
        #   也可以直接resize进行识别
        # ---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # ---------------------------------------------------------#
        #   添加上batch_size维度，图片预处理，归一化。
        # ---------------------------------------------------------#
        image_data = preprocess_input(np.expand_dims(np.array(image_data, dtype='float32'), 0))

        preds = self.get_pred(image_data).numpy()
        # -----------------------------------------------------------#
        #   将预测结果进行解码
        # -----------------------------------------------------------#
        results = self.bbox_util.decode_box(preds, self.anchors, image_shape,
                                            self.input_shape, self.letterbox_image, confidence=self.confidence)
        t1 = time.time()
        for _ in range(test_interval):
            preds = self.get_pred(image_data).numpy()
            # -----------------------------------------------------------#
            #   将预测结果进行解码
            # -----------------------------------------------------------#
            results = self.bbox_util.decode_box(preds, self.anchors, image_shape,
                                                self.input_shape, self.letterbox_image, confidence=self.confidence)
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        f = open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w")
        image_shape = np.array(np.shape(image)[0:2])
        # ---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        # ---------------------------------------------------------#
        image = cvtColor(image)
        # ---------------------------------------------------------#
        #   给图像增加灰条，实现不失真的resize
        #   也可以直接resize进行识别
        # ---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # ---------------------------------------------------------#
        #   添加上batch_size维度，图片预处理，归一化。
        # ---------------------------------------------------------#
        image_data = preprocess_input(np.expand_dims(np.array(image_data, dtype='float32'), 0))

        preds = self.get_pred(image_data).numpy()
        # -----------------------------------------------------------#
        #   将预测结果进行解码
        # -----------------------------------------------------------#
        results = self.bbox_util.decode_box(preds, self.anchors, image_shape,
                                            self.input_shape, self.letterbox_image, confidence=self.confidence)
        # --------------------------------------#
        #   如果没有检测到物体，则返回原图
        # --------------------------------------#
        if len(results[0]) <= 0:
            return

        top_label = results[0][:, 4]
        top_conf = results[0][:, 5]
        top_boxes = results[0][:, :4]

        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = str(top_conf[i])

            top, left, bottom, right = box

            if predicted_class not in class_names:
                continue

            f.write("%s %s %s %s %s %s\n" % (
            predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

        f.close()
        return
